{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MNIST DCGAN\n",
    "\n",
    "We're going to create a GAN that generates synthetic handwritten digits.\n",
    "\n",
    "Code used and modifed from:\n",
    "- https://github.com/Zackory/Keras-MNIST-GAN/blob/master/mnist_dcgan.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "dense_1 (Dense)              (None, 6272)              633472    \n",
      "_________________________________________________________________\n",
      "leaky_re_lu_1 (LeakyReLU)    (None, 6272)              0         \n",
      "_________________________________________________________________\n",
      "reshape_1 (Reshape)          (None, 128, 7, 7)         0         \n",
      "_________________________________________________________________\n",
      "up_sampling2d_1 (UpSampling2 (None, 128, 14, 14)       0         \n",
      "_________________________________________________________________\n",
      "conv2d_1 (Conv2D)            (None, 64, 14, 14)        204864    \n",
      "_________________________________________________________________\n",
      "leaky_re_lu_2 (LeakyReLU)    (None, 64, 14, 14)        0         \n",
      "_________________________________________________________________\n",
      "up_sampling2d_2 (UpSampling2 (None, 64, 28, 28)        0         \n",
      "_________________________________________________________________\n",
      "conv2d_2 (Conv2D)            (None, 1, 28, 28)         1601      \n",
      "=================================================================\n",
      "Total params: 839,937\n",
      "Trainable params: 839,937\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "None\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "conv2d_3 (Conv2D)            (None, 64, 14, 14)        1664      \n",
      "_________________________________________________________________\n",
      "leaky_re_lu_3 (LeakyReLU)    (None, 64, 14, 14)        0         \n",
      "_________________________________________________________________\n",
      "dropout_1 (Dropout)          (None, 64, 14, 14)        0         \n",
      "_________________________________________________________________\n",
      "conv2d_4 (Conv2D)            (None, 128, 7, 7)         204928    \n",
      "_________________________________________________________________\n",
      "leaky_re_lu_4 (LeakyReLU)    (None, 128, 7, 7)         0         \n",
      "_________________________________________________________________\n",
      "dropout_2 (Dropout)          (None, 128, 7, 7)         0         \n",
      "_________________________________________________________________\n",
      "flatten_1 (Flatten)          (None, 6272)              0         \n",
      "_________________________________________________________________\n",
      "dense_2 (Dense)              (None, 1)                 6273      \n",
      "=================================================================\n",
      "Total params: 212,865\n",
      "Trainable params: 212,865\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "None\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from keras.layers import Input\n",
    "from keras.models import Model, Sequential\n",
    "from keras.layers.core import Reshape, Dense, Dropout, Flatten\n",
    "from keras.layers.advanced_activations import LeakyReLU\n",
    "from keras.layers.convolutional import Conv2D, UpSampling2D\n",
    "from keras.datasets import mnist\n",
    "from keras.optimizers import Adam\n",
    "from keras import backend as K\n",
    "from keras import initializers\n",
    "\n",
    "K.set_image_dim_ordering('th')\n",
    "\n",
    "# The dimensionality has been set at 100 for consistency with other GAN implementations. \n",
    "# But 10 works better here\n",
    "latent_dim = 100\n",
    "\n",
    "# Load MNIST data\n",
    "(X_train, y_train), (X_test, y_test) = mnist.load_data()\n",
    "X_train = (X_train.astype(np.float32) - 127.5)/127.5\n",
    "X_train = X_train[:, np.newaxis, :, :]\n",
    "\n",
    "# Use Adam as the Optimizer\n",
    "adam = Adam(lr=0.0002, beta_1=0.5)\n",
    "\n",
    "# Make our Generator Model\n",
    "generator = Sequential()\n",
    "\n",
    "# Transforms the input into a 7 × 7 128-channel feature map\n",
    "generator.add(Dense(128*7*7, input_dim=latent_dim))\n",
    "generator.add(LeakyReLU(0.2))\n",
    "generator.add(Reshape((128, 7, 7)))\n",
    "generator.add(UpSampling2D(size=(2, 2)))\n",
    "generator.add(Conv2D(64, kernel_size=(5, 5), padding='same'))\n",
    "generator.add(LeakyReLU(0.2))\n",
    "generator.add(UpSampling2D(size=(2, 2)))\n",
    "\n",
    "# Produces a 28 × 28 1-channel feature map (shape of a MNIST image)\n",
    "generator.add(Conv2D(1, kernel_size=(5, 5), padding='same', activation='tanh'))\n",
    "print(generator.summary())\n",
    "generator.compile(loss='binary_crossentropy', optimizer=adam)\n",
    "\n",
    "# Make our Discriminator Model\n",
    "discriminator = Sequential()\n",
    "discriminator.add(Conv2D(64, kernel_size=(5, 5), strides=(2, 2), padding='same', \n",
    "                         input_shape=(1, 28, 28), kernel_initializer=initializers.RandomNormal(stddev=0.02)))\n",
    "discriminator.add(LeakyReLU(0.2))\n",
    "discriminator.add(Dropout(0.3))\n",
    "discriminator.add(Conv2D(128, kernel_size=(5, 5), strides=(2, 2), padding='same'))\n",
    "discriminator.add(LeakyReLU(0.2))\n",
    "discriminator.add(Dropout(0.3))\n",
    "discriminator.add(Flatten())\n",
    "discriminator.add(Dense(1, activation='sigmoid'))\n",
    "print(discriminator.summary())\n",
    "discriminator.compile(loss='binary_crossentropy', optimizer=adam)\n",
    "\n",
    "# Creating the Adversarial Network. We need to make the Discriminator weights\n",
    "# non trainable. This only applies to the GAN model.\n",
    "discriminator.trainable = False\n",
    "ganInput = Input(shape=(latent_dim,))\n",
    "x = generator(ganInput)\n",
    "ganOutput = discriminator(x)\n",
    "gan = Model(inputs=ganInput, outputs=ganOutput)\n",
    "gan.compile(loss='binary_crossentropy', optimizer=adam)\n",
    "\n",
    "# Our Discriminator and Generator Losses\n",
    "dLosses = []\n",
    "gLosses = []\n",
    "\n",
    "# Plot the loss from each batch\n",
    "def plotLoss(epoch):\n",
    "    plt.figure(figsize=(10, 8))\n",
    "    plt.plot(dLosses, label='Discriminitive loss')\n",
    "    plt.plot(gLosses, label='Generative loss')\n",
    "    plt.xlabel('Epoch')\n",
    "    plt.ylabel('Loss')\n",
    "    plt.legend()\n",
    "    plt.savefig('images/dcgan_loss_epoch_%d.png' % epoch)\n",
    "\n",
    "# Create a wall of generated MNIST images\n",
    "def plotGeneratedImages(epoch, examples=100, dim=(10, 10), figsize=(10, 10)):\n",
    "    noise = np.random.normal(0, 1, size=[examples, latent_dim])\n",
    "    generatedImages = generator.predict(noise)\n",
    "\n",
    "    plt.figure(figsize=figsize)\n",
    "    for i in range(generatedImages.shape[0]):\n",
    "        plt.subplot(dim[0], dim[1], i+1)\n",
    "        plt.imshow(generatedImages[i, 0], interpolation='nearest', cmap='gray_r')\n",
    "        plt.axis('off')\n",
    "    plt.tight_layout()\n",
    "    plt.savefig('images/dcgan_generated_image_epoch_%d.png' % epoch)\n",
    "\n",
    "# Save the generator and discriminator networks (and weights) for later use\n",
    "def saveModels(epoch):\n",
    "    generator.save('models/dcgan_generator_epoch_%d.h5' % epoch)\n",
    "    discriminator.save('models/dcgan_discriminator_epoch_%d.h5' % epoch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train our GAN and Plot the Synthetic Image Outputs \n",
    "\n",
    "After each consecutive Epoch we can see how synthetic images being improved "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 0/468 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epochs: 5\n",
      "Batch size: 128\n",
      "Batches per epoch: 468.75\n",
      "--------------- Epoch 1 ---------------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 468/468 [24:11<00:00,  2.88s/it]\n",
      "  0%|          | 0/468 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------------- Epoch 2 ---------------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|▌         | 28/468 [01:24<23:54,  3.26s/it]"
     ]
    }
   ],
   "source": [
    "epochs = 5\n",
    "batchSize = 128\n",
    "batchCount = X_train.shape[0] / batchSize\n",
    "\n",
    "print('Epochs:', epochs)\n",
    "print('Batch size:', batchSize)\n",
    "print('Batches per epoch:', batchCount)\n",
    "\n",
    "for e in range(1, epochs+1):\n",
    "    print('-'*15, 'Epoch %d' % e, '-'*15)\n",
    "    for i in tqdm(range(int(batchCount))):\n",
    "        # Get a random set of input noise and images\n",
    "        noise = np.random.normal(0, 1, size=[batchSize, latent_dim])\n",
    "        imageBatch = X_train[np.random.randint(0, X_train.shape[0], size=batchSize)]\n",
    "\n",
    "        # Generate fake MNIST images\n",
    "        generatedImages = generator.predict(noise)\n",
    "        X = np.concatenate([imageBatch, generatedImages])\n",
    "\n",
    "        # Labels for generated and real data\n",
    "        yDis = np.zeros(2*batchSize)\n",
    "        # One-sided label smoothing\n",
    "        yDis[:batchSize] = 0.9\n",
    "\n",
    "        # Train discriminator\n",
    "        discriminator.trainable = True\n",
    "        dloss = discriminator.train_on_batch(X, yDis)\n",
    "\n",
    "        # Train generator\n",
    "        noise = np.random.normal(0, 1, size=[batchSize, latent_dim])\n",
    "        yGen = np.ones(batchSize)\n",
    "        discriminator.trainable = False\n",
    "        gloss = gan.train_on_batch(noise, yGen)\n",
    "\n",
    "    # Store loss of most recent batch from this epoch\n",
    "    dLosses.append(dloss)\n",
    "    gLosses.append(gloss)\n",
    "\n",
    "    if e == 1 or e % 5 == 0:\n",
    "        plotGeneratedImages(e)\n",
    "        \n",
    "        # Plot losses from every epoch\n",
    "        plotLoss(e)\n",
    "        saveModels(e)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
